Allow user supplied overlap functions #48
base: master
Conversation
This is quite exciting! I left a couple of specific comments below. Here, let me go over the callsite-specialization issue that appears below, as it may be something to keep in mind more generally.
Let's compare two implementations of a simple "depot" function:
```julia
julia> function dotrig1(x, funcname)
           f = funcname == "sin" ? sin :
               funcname == "cos" ? cos :
               funcname == "tan" ? tan :
               funcname == "sec" ? sec : error("$funcname not supported")
           return f(x)
       end

julia> function dotrig2(x, funcname)
           funcname == "sin" && return sin(x)
           funcname == "cos" && return cos(x)
           funcname == "tan" && return tan(x)
           funcname == "sec" && return sec(x)
           error("$funcname not supported")
       end
```

Now compare:
```julia
using Cthulhu
@descend iswarn=true dotrig1(π/4, "sin")
@descend iswarn=true dotrig2(π/4, "sin")
```

You'll note that `dotrig1` is inferred to return `Any` whereas `dotrig2` is inferred to return `Float64`. What's happening here is callsite specialization: for `dotrig1`, there's a single place in the code that makes the `f(x)` call, and so the compiler has to allow that call to be generic: it inserts code that says, "OK, what is `f`? Can I find a compiled method for that `f` and that `::typeof(x)`? If not, compile it; then run it." There's a lot of "state" that might affect the return type of `f(x)`, and at a certain point Julia's type-inference just gives up. (It has to: type-inference is subject to the halting problem, so it has to have heuristics that self-terminate.)
In contrast, with `dotrig2`, there are 4 separate call sites for `f(x)`, each corresponding to a different `f`. This allows the compiler to specialize each one of those sites differently, and since it knows the `f` with certainty it can fully specialize each one, precompile all the calls, and thus execute the function just by jumping to a specific known-in-advance memory address that gets hardwired into the compiled code. You can't get more efficient than that (well, with inlining you might, but the good news is that Julia will even do that if appropriate).
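For reference, if you don't have Cthulhu handy, the same return-type comparison can be made with `@code_warntype` from the standard-library InteractiveUtils; the `Body::` comments below just restate the inference results described above:

```julia
using InteractiveUtils  # loaded by default in the REPL

@code_warntype dotrig1(π/4, "sin")  # Body::Any — inference gives up at the generic call
@code_warntype dotrig2(π/4, "sin")  # Body::Float64 — every call site fully inferred
```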
A related issue: if you have a list of functions that you want to apply to a lot of different variables, then
```julia
for x in varlist
    for f in funclist
        mysum += f(x)
    end
end
```

will be really bad, because there isn't a single call that can be predicted in advance: each and every one of those O(m*n) calls has to be individually analyzed ("what's `f`?"). In contrast,
```julia
for f in funclist
    mysum += apply_to_list(f, xlist)
end

@noinline apply_to_list(f::F, xlist) where F = sum(f, xlist)
```

forces Julia to specialize `apply_to_list` for each different `f`: Julia may not be able to predict which `f` will come out of `funclist`, but once it figures that out it calls one of many different compiled versions of `apply_to_list`, each of which is specialized to a particular `f` and so is highly efficient for iterating over `xlist` and aggregating the output. In other words, this is O(m) rather than O(m*n) in its runtime-dispatch performance.
This is an example of the function barrier trick. The `f::F ... where F` would ordinarily do nothing, but Julia has special heuristics for function- and type-arguments that avoid specialization in some cases (the number of `Real` subtypes is presumably limited in practice, but the number of `Function` subtypes is effectively unbounded, so to avoid creating an infinite amount of compiled code Julia just decides not to specialize some code). In this case you may want to disable those heuristics, and adding a type-parameter achieves that. The `@noinline` is probably unnecessary with modern Julia, but at one point in Julia's development it was important to prevent inlining from defeating your efforts. I'm old-school so I insert these as a precaution.
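A minimal sketch of when that heuristic bites, as I understand Julia's specialization rules (a function argument that is merely passed through, rather than called, in the method body is a candidate for non-specialization):

```julia
# `f` is only passed through to `sum`, never called directly here,
# so Julia may choose not to specialize on it:
passthrough(f, xs) = sum(f, xs)

# The type parameter forces a separate specialization per concrete `f`:
forced(f::F, xs) where F = sum(f, xs)
```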
> `σᵣ` and `σₜ` represent the sizes of the rotation and translation uncertainty regions.
> The `objective` should be a function that takes the squared distance between the means of two `IsotropicGaussian`s, the sum of their variances, and the product of their amplitudes.
You should spell out the exact arguments. From this I infer that the argument order is `objective(Δμ, σ²sum, ϕprod)`, but best to be explicit.
Also, does `objective` need to satisfy certain requirements? E.g., does it have to be monotonic? (Lennard-Jones comes to mind.) Do you need to specialize any methods for your objective function? E.g., `estimate_lower_bound(::typeof(lennardjones), Δμ, σ²sum, ϕprod)`. If there is an API that the user-supplied function needs to satisfy, it should be spelled out.
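For instance, a conforming objective under the argument order inferred above might look like this (an illustrative sketch of the call signature, not the package's actual definition):

```julia
# Assumed argument order (my inference from the docs quoted above):
#   distsq — squared distance between the two means
#   s      — sum of the two variances
#   w      — product of the two amplitudes
my_objective(distsq, s, w) = w * exp(-distsq / (2s))  # illustrative Gaussian-like form
```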
You're right that Lennard-Jones wouldn't work here without some careful thought -- maybe handling the attractive and repulsive terms separately.
The assumption is that the objective needs to be monotonically decreasing with increasing Δμ² (which is absolutely important to be clear about).
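For what it's worth, a sketch of the split mentioned above: each Lennard-Jones term is individually monotonic in the squared distance, so each could in principle be bounded separately (illustrative only; `σ` here is the usual LJ length scale, not the uncertainty sizes above):

```julia
# Repulsive term: monotonically decreasing in distsq
lj_repulsive(distsq, σ) = (σ^2 / distsq)^6
# Attractive term: monotonically increasing in distsq (toward zero from below)
lj_attractive(distsq, σ) = -(σ^2 / distsq)^3
```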
```julia
obj = !isdict ? objective : (haskey(objective, (key1,key2)) ? objective[(key1,key2)] : objective[(key2,key1)])
lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...)
```
If profiling reveals this line to be a bottleneck, here's one way that might help a tiny bit:
```diff
-obj = !isdict ? objective : (haskey(objective, (key1,key2)) ? objective[(key1,key2)] : objective[(key2,key1)])
-lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...)
+if !isdict
+    lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective, kwargs...)  # this call *might* be inferrable
+else
+    obj = get(objective, (key1, key2), nothing)
+    if obj === nothing
+        obj = get(objective, (key2, key1), nothing)
+    end
+    if obj !== nothing
+        lb, ub = (lb, ub) .+ generic_bounds(mgmmx.gmms[key1], mgmmy.gmms[key2], R, T, σᵣ, σₜ, pσ, mpϕ[key1][key2]; objective = obj, kwargs...)  # this call is not
+    end
+end
```
The important part of this is the `!isdict` case, which gives Julia a chance to "pass down" knowledge of `typeof(objective)` from the input arguments to the call to `generic_bounds`, which can be specialized for `objective=objective`. For the second one, if you know (or can compute) the return type of `generic_bounds`, then you might want to add a type-annotation, e.g., `generic_bounds(args...; kwargs...)::Tuple{T,T}`, so that `lb, ub` has known type even if Julia can't infer its way through the entire call.
The other part of the change is vastly less important, but exploits the fact that

```julia
haskey(dict, a) ? dict[a] : nothing
```

involves looking up the key `a` twice, whereas

```julia
get(dict, a, nothing)
```

only looks up the key `a` once. Note that while `nothing` is a conventional default, if `nothing` is in fact a legitimate user-supplied value in the dictionary, you can ensure there's no ambiguity about whether the key was present in the dictionary as follows:

```julia
struct NotFound end  # a private type for internal use only
const notfound = NotFound()

x = get(dict, key, notfound)
if x !== notfound
    ...
end
```

Then there's no way that `notfound` was retrieved from `dict` (or if it is, it's clearly the user's fault).
You probably have enough places in your code that check both orders of the keys that it might be worth splitting the double-`get` block into a utility function, something like the sketch below.
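A hypothetical helper along those lines (the name is my invention; the `notfound` sentinel is the one defined above):

```julia
# Look up (key1, key2), falling back to (key2, key1); throws a KeyError if neither
# order is present, matching the behavior of the original indexing expression.
function get_symmetric(dict, key1, key2)
    x = get(dict, (key1, key2), notfound)
    x !== notfound && return x
    return dict[(key2, key1)]
end
```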
```julia
randtform = AffineMap(RotationVec(π*0.1rand(3)...), SVector{3}(0.1*rand(3)...))
```

```julia
# allowing some fuzziness in the distance
relaxed_overlap(distsq, s, w) = gaussian_overlap(max(0, distsq - sign(w) * 0.5), s, w)
```
👍
Thanks for the feedback! I was worried about runtime dispatch issues, and the function barrier trick seems really helpful here.
Another option that might be worth taking seriously would be stashing the interaction overlap functions in

```julia
struct InteractionOverlap{FS,FH,FI,FP}
    steric::FS
    hydrophobic::FH
    ionic::FI
    polar::FP
end
```

I'm not sure that's a good idea, but it certainly would make everything inferrable. But profiling is the real decider here. "Strategic non-inferrability" can be a good thing and may not hurt your performance while also letting you simplify your code. Mostly it's about knowing what bottlenecks you have and what tricks you have at your disposal for fixing them. It's usually not worth fixing inference problems unless they are affecting performance.
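A hedged sketch of how that might look in use (the field assignments are illustrative; `gaussian_overlap` and `relaxed_overlap` are from earlier in this thread):

```julia
# Each field gets its own type parameter, so each overlap call is fully inferrable:
ov = InteractionOverlap(gaussian_overlap, relaxed_overlap, gaussian_overlap, gaussian_overlap)
ov.steric(1.0, 2.0, 0.5)  # the field type pins down exactly which function is called
```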
This adds a keyword argument, `objective`, to the `overlap` function. A way to pass objective functions through to the bounds calculations (e.g. `gogma_align` -> `branchbound` -> `generic_overlap` -> passed `objective`) has also been added. Further, objective functions for multi-GMMs can be supplied in a dictionary similar to the format for `interactions`:
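For illustration, such a dictionary might look like the following; the key names are hypothetical, and the tuple-key format is taken from the diff discussed above, where both key orders are checked:

```julia
# Hypothetical interaction keys; each pair maps to its own objective function
objective = Dict(
    (:steric, :steric)  => gaussian_overlap,
    (:donor, :acceptor) => relaxed_overlap,
)
overlap(mgmmx, mgmmy; objective)  # sketch of the new keyword in use
```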